Linearization Pipeline
%reload_ext autoreload
%autoreload 2
This pipeline takes 2D position data from the IntervalPositionInfo table and "linearizes" it to 1D position.
1. Retrieving 2D position from the IntervalPositionInfo table
To retrieve 2D position data, we need to specify an NWB file, a position time interval, and the set of parameters used to compute the position info.
First, we specify the NWB file and get the name of its copy (without the ephys data), which is used as the key.
from spyglass.common.nwb_helper_fn import get_nwb_copy_filename
import spyglass as nd
nwb_file_name = "chimi20200216_new.nwb"
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)
nwb_copy_file_name
'chimi20200216_new_.nwb'
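The output above shows the naming rule for the copy: an underscore is inserted before the .nwb extension. The hypothetical helper below reproduces that rule for illustration only. It is inferred from this single example and is not the source of spyglass's get_nwb_copy_filename.

```python
from pathlib import PurePath


def copy_filename_sketch(nwb_file_name: str) -> str:
    """Hypothetical helper: insert "_" before the file extension.

    Mirrors the naming seen in the output above; NOT spyglass's own code.
    """
    path = PurePath(nwb_file_name)
    return f"{path.stem}_{path.suffix}"


print(copy_filename_sketch("chimi20200216_new.nwb"))  # chimi20200216_new_.nwb
```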
Now we can specify which interval we want to look at (and which parameters we used if we've run it with more than one set of parameters).
We will fetch the pandas dataframe from the IntervalPositionInfo table for easy plotting.
(Note: we may want a way to specify which type of position is being used, for example in case the position comes from DeepLabCut.)
from spyglass.common.common_position import IntervalPositionInfo
position_info = (
IntervalPositionInfo()
& {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"position_info_param_name": "default",
}
).fetch1_dataframe()
position_info
/stelmo/nwb/analysis/chimi20200216_new_ZF24K2JHC1.nwb
| time | head_position_x | head_position_y | head_orientation | head_velocity_x | head_velocity_y | head_speed |
|---|---|---|---|---|---|---|
| 1.581887e+09 | 91.051650 | 211.127050 | 2.999696 | 1.387074 | 2.848838 | 3.168573 |
| 1.581887e+09 | 90.844337 | 211.417287 | 3.078386 | 3.123201 | 3.411111 | 4.624939 |
| 1.581887e+09 | 90.637025 | 211.707525 | -3.114572 | 5.431643 | 4.089597 | 6.799085 |
| 1.581887e+09 | 90.802875 | 211.596958 | -3.033109 | 8.097753 | 4.979262 | 9.506138 |
| 1.581887e+09 | 91.288579 | 211.482443 | -3.062550 | 10.840482 | 6.071373 | 12.424880 |
| ... | ... | ... | ... | ... | ... | ... |
| 1.581888e+09 | 182.158583 | 201.452467 | -0.986926 | 0.348276 | 0.218575 | 0.411182 |
| 1.581888e+09 | 182.158583 | 201.397183 | -0.978610 | 0.279135 | -0.058413 | 0.285182 |
| 1.581888e+09 | 182.213867 | 201.341900 | -0.957589 | 0.193798 | -0.283200 | 0.343162 |
| 1.581888e+09 | 182.158583 | 201.341900 | -0.970083 | 0.110838 | -0.417380 | 0.431846 |
| 1.581888e+09 | 182.158583 | 201.286617 | -0.936414 | 0.045190 | -0.453966 | 0.456209 |
39340 rows × 6 columns
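The velocity and speed columns are related: the rows shown above are consistent with head_speed being the Euclidean norm of (head_velocity_x, head_velocity_y). This is our inference from the numbers rather than a documented guarantee. A quick check with NumPy on the first two rows:

```python
import numpy as np

# First two rows of the table above (units presumably cm/s)
head_velocity_x = np.array([1.387074, 3.123201])
head_velocity_y = np.array([2.848838, 3.411111])

# Euclidean norm of the 2D velocity vector at each time point
head_speed = np.hypot(head_velocity_x, head_velocity_y)
print(head_speed)  # approximately [3.168573, 4.624939], matching head_speed
```

On the full dataframe, the same check would be np.hypot(position_info.head_velocity_x, position_info.head_velocity_y) compared against position_info.head_speed.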
Let's linearize the head position. We will plot the head position to get a sense of the data.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(
position_info.head_position_x,
position_info.head_position_y,
color="lightgrey",
)
ax.set_xlabel("x-position [cm]", fontsize=18)
ax.set_ylabel("y-position [cm]", fontsize=18)
ax.set_title("Head Position", fontsize=28)
2. Specifying the track graph for linearization
In order to linearize the data, we must build a graph of nodes and edges that specify the geometry of the track in 1D and 2D. This will be referred to as the TrackGraph.
We do this by specifying four variables:
- node_positions are the 2D positions of the graph nodes (in cm).
- edges specify how the nodes are connected. Each edge is a pair of nodes, and each node is labeled by its index in node_positions. For example, (79.910, 216.720) is the 2D position of node 0 and (183.784, 45.375) is the 2D position of node 8, so specifying (0, 8) would mean there is an edge between node 0 and node 8.
- linear_edge_order specifies the order in which the edges are laid out in linear space. As before, each edge is a pair of nodes labeled by their index. The order of the two nodes controls their order in 1D space: edge (0, 1) connects node 0 and node 1, and node 0 will be placed before node 1 in linear position space. Specifying edge (1, 0) would reverse the linear positions for that edge.
- linear_edge_spacing specifies the spacing between consecutive edges in linear space. This can be either a single number or an array with one entry per gap between edges. If it is a single number, all edges will have the same gap between them (15 cm in this example). If it is an array, the spacing between each pair of edges can be controlled individually. You may want gaps between edges that are not spatially connected in 2D space.
For more examples, see the notebooks in the track_linearization repository: https://github.com/LorenFrankLab/track_linearization/blob/master/notebooks/
import numpy as np
node_positions = np.array(
[
(79.910, 216.720), # top left well 0
(132.031, 187.806), # top middle intersection 1
(183.718, 217.713), # top right well 2
(132.544, 132.158), # middle intersection 3
(87.202, 101.397), # bottom left intersection 4
(31.340, 126.110), # middle left well 5
(180.337, 104.799), # middle right intersection 6
(92.693, 42.345), # bottom left well 7
(183.784, 45.375), # bottom right well 8
(231.338, 136.281), # middle right well 9
]
)
edges = np.array(
[
(0, 1),
(1, 2),
(1, 3),
(3, 4),
(4, 5),
(3, 6),
(6, 9),
(4, 7),
(6, 8),
]
)
linear_edge_order = [
(3, 6),
(6, 8),
(6, 9),
(3, 1),
(1, 2),
(1, 0),
(3, 4),
(4, 5),
(4, 7),
]
linear_edge_spacing = 15
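To make the roles of linear_edge_order and linear_edge_spacing concrete, here is a minimal sketch in plain NumPy (not the track_linearization implementation). It checks that every ordered edge is one of the graph's edges and computes where each edge would start and end in 1D. It assumes, based on the description above, that edges are placed end to end in linear_edge_order, that each keeps its 2D Euclidean length, and that a single linear_edge_spacing separates consecutive edges.

```python
import numpy as np

# Values copied from the cell above so that this sketch runs on its own
node_positions = np.array(
    [
        (79.910, 216.720), (132.031, 187.806), (183.718, 217.713),
        (132.544, 132.158), (87.202, 101.397), (31.340, 126.110),
        (180.337, 104.799), (92.693, 42.345), (183.784, 45.375),
        (231.338, 136.281),
    ]
)
edges = [(0, 1), (1, 2), (1, 3), (3, 4), (4, 5), (3, 6), (6, 9), (4, 7), (6, 8)]
linear_edge_order = [
    (3, 6), (6, 8), (6, 9), (3, 1), (1, 2), (1, 0), (3, 4), (4, 5), (4, 7)
]
linear_edge_spacing = 15

# Sanity check: each entry of linear_edge_order must be a graph edge,
# in either direction
edge_set = {frozenset(edge) for edge in edges}
assert all(frozenset(edge) in edge_set for edge in linear_edge_order)

# Lay the edges end to end in linear_edge_order with a fixed gap between them;
# each edge keeps its 2D Euclidean length
layout = {}
start = 0.0
for u, v in linear_edge_order:
    length = float(np.linalg.norm(node_positions[v] - node_positions[u]))
    layout[(u, v)] = (start, start + length)
    start += length + linear_edge_spacing

for (u, v), (lo, hi) in layout.items():
    print(f"edge ({u}, {v}): {lo:6.1f} to {hi:6.1f} cm")
```

The plot_track_graph_as_1D method of TrackGraph draws the authoritative layout; this sketch is only meant to show how the ordering and spacing combine.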
Once we have these variables, we must give the track graph a name (track_graph_name). We can also specify the environment it corresponds to (environment).
from spyglass.common.common_position import TrackGraph
TrackGraph.insert1(
{
"track_graph_name": "6 arm",
"environment": "6 arm",
"node_positions": node_positions,
"edges": edges,
"linear_edge_order": linear_edge_order,
"linear_edge_spacing": linear_edge_spacing,
},
skip_duplicates=True,
)
graph = TrackGraph() & {"track_graph_name": "6 arm"}
graph
| track_graph_name | environment: Type of Environment | node_positions: 2D position of track_graph nodes, shape (n_nodes, 2) | edges: shape (n_edges, 2) | linear_edge_order: order of track graph edges in the linear space, shape (n_edges, 2) | linear_edge_spacing: amount of space between edges in the linear space, shape (n_edges,) | linear_edge_specialty: denote what edges (denote by 5) are going to be lumped to what edge (denote by 1), shape (n_edges,) |
|---|---|---|---|---|---|---|
| 6 arm | 6 arm | =BLOB= | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
Total: 1
The TrackGraph table has several convenient methods for visualizing the graph in 2D and 1D space. Here we use the plot_track_graph method to plot the graph in 2D. Note that we call it on the TrackGraph table restricted to our track graph's name (graph). It is important to plot the track graph in 2D over the position data to make sure the layout makes sense.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(
position_info.head_position_x,
position_info.head_position_y,
color="lightgrey",
alpha=0.7,
zorder=0,
)
ax.set_xlabel("x-position [cm]", fontsize=18)
ax.set_ylabel("y-position [cm]", fontsize=18)
graph.plot_track_graph(ax=ax)
We can also look at how this will translate to 1D space by using the plot_track_graph_as_1D method.
fig, ax = plt.subplots(1, 1, figsize=(20, 1))
graph.plot_track_graph_as_1D(ax=ax)
3. Setting up the parameters for linearization
There are several other parameters we can set for linearization. Apart from use_hmm itself, they are only relevant if you choose to use the HMM method of linearization.
By default, linearization assigns each 2D position to its nearest point on the track graph. This is then translated into 1D space.
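A minimal sketch of nearest-point assignment, assuming "nearest" means the smallest Euclidean distance to each straight edge. This is plain NumPy, not the code spyglass or track_linearization actually runs. For the first head position in section 1 it picks edge index 0 and a projected position of about (90.80, 210.68), which matches the first row of the linearized output in section 4.

```python
import numpy as np

# Track graph from section 2, repeated so this sketch runs on its own
node_positions = np.array(
    [
        (79.910, 216.720), (132.031, 187.806), (183.718, 217.713),
        (132.544, 132.158), (87.202, 101.397), (31.340, 126.110),
        (180.337, 104.799), (92.693, 42.345), (183.784, 45.375),
        (231.338, 136.281),
    ]
)
edges = np.array(
    [(0, 1), (1, 2), (1, 3), (3, 4), (4, 5), (3, 6), (6, 9), (4, 7), (6, 8)]
)


def project_onto_segment(point, a, b):
    """Return the point on segment a-b that is closest to `point`."""
    ab = b - a
    # Position along the segment as a fraction of its length, clamped to
    # [0, 1] so the result never falls beyond the segment's end nodes
    t = np.clip(np.dot(point - a, ab) / np.dot(ab, ab), 0.0, 1.0)
    return a + t * ab


def nearest_point_on_graph(point):
    """Return (edge index, projected point) for the edge closest to `point`."""
    projections = [
        project_onto_segment(point, node_positions[u], node_positions[v])
        for u, v in edges
    ]
    distances = [np.linalg.norm(point - p) for p in projections]
    best = int(np.argmin(distances))
    return best, projections[best]


# First head position in the position_info table (section 1)
segment_id, projected = nearest_point_on_graph(np.array([91.051650, 211.127050]))
print(segment_id, projected)
```

Because each time point is handled independently, a noisy sample near an intersection can land on a different edge than its neighbors, which is the motivation for the HMM option.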
If use_hmm is set, a hidden Markov model (HMM) is used to assign these points instead. The HMM can be useful because it takes into account the animal's previous position and the edge it was on. This can keep the position from suddenly jumping to another edge, for example at an intersection, or when two reward wells are close to each other and the animal's head swings closer to the neighboring well even though the animal is physically on a different edge of the track.
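To illustrate why a "sticky" HMM helps, here is a toy segment-level HMM decoded with the Viterbi algorithm. It is not the model track_linearization uses: the Gaussian emission on distance-to-segment and the transition matrix built from a single diagonal_bias value are simplifying assumptions, borrowed loosely from the sensor_std_dev and diagonal_bias parameter names in the table below, and the default diagonal_bias=0.9 here is an arbitrary choice for the example.

```python
import numpy as np


def viterbi_segments(distances, sensor_std_dev=5.0, diagonal_bias=0.9):
    """Most likely sequence of track segments under a toy 'sticky' HMM.

    distances: array (n_time, n_segments), distance from the position at each
    time to each segment. Emissions are Gaussian in that distance, and the
    animal stays on its current segment with probability `diagonal_bias`.
    """
    n_time, n_segments = distances.shape
    if n_segments < 2:
        return [0] * n_time
    log_emission = -0.5 * (distances / sensor_std_dev) ** 2
    transition = np.full(
        (n_segments, n_segments), (1.0 - diagonal_bias) / (n_segments - 1)
    )
    np.fill_diagonal(transition, diagonal_bias)
    log_transition = np.log(transition)

    # Uniform initial distribution over segments
    score = log_emission[0] - np.log(n_segments)
    backpointer = np.zeros((n_time, n_segments), dtype=int)
    for t in range(1, n_time):
        # candidates[i, j]: best path ending on segment i, then moving to j
        candidates = score[:, None] + log_transition
        backpointer[t] = np.argmax(candidates, axis=0)
        score = candidates.max(axis=0) + log_emission[t]

    # Trace the best path backwards from the most likely final segment
    path = [int(np.argmax(score))]
    for t in range(n_time - 1, 0, -1):
        path.append(int(backpointer[t, path[-1]]))
    return path[::-1]


# Two segments; at t = 2 the head swings slightly closer to segment 1
distances = np.array([[1, 10], [1, 10], [6, 5], [1, 10], [1, 10]], dtype=float)
print(np.argmin(distances, axis=1).tolist())  # nearest segment: [0, 0, 1, 0, 0]
print(viterbi_segments(distances))  # sticky HMM: [0, 0, 0, 0, 0]
```

The single-sample jump to segment 1 costs two unlikely transitions, which outweighs its slightly better emission, so the decoded path stays on segment 0.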
from spyglass.common.common_position import LinearizationParameters
LinearizationParameters.insert1(
{"linearization_param_name": "default"}, skip_duplicates=True
)
LinearizationParameters()
| linearization_param_name: name for this set of parameters | use_hmm: use HMM to determine linearization | route_euclidean_distance_scaling: How much to prefer route distances between successive time points that are closer to the euclidean distance. Smaller numbers mean the route distance is more likely to be close to the euclidean distance. | sensor_std_dev: Uncertainty of position sensor (in cm). | diagonal_bias: Biases the transition matrix to prefer the current track segment. |
|---|---|---|---|---|
| default | 0 | 1.0 | 5.0 | 0.5 |
Total: 1
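Once we have linearization parameters, we pair them, as we did for the 2D position, with the position interval we want to linearize. We do this by inserting an entry into IntervalLinearizationSelection that also names the position info parameters and the track graph to use.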
from spyglass.common.common_position import IntervalLinearizationSelection
IntervalLinearizationSelection.insert1(
{
"position_info_param_name": "default",
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"track_graph_name": "6 arm",
"linearization_param_name": "default",
},
skip_duplicates=True,
)
IntervalLinearizationSelection()
| position_info_param_name: name for this set of parameters | nwb_file_name: name of the NWB file | interval_list_name: descriptive name of this interval list | track_graph_name | linearization_param_name: name for this set of parameters |
|---|---|---|---|---|
| default | chimi20200216_new_.nwb | pos 1 valid times | 6 arm | default |
Total: 1
And then we can run the linearization by populating the IntervalLinearizedPosition table.
from spyglass.common.common_position import IntervalLinearizedPosition
IntervalLinearizedPosition().populate()
IntervalLinearizedPosition()
Computing linear position for: {'position_info_param_name': 'default', 'nwb_file_name': 'chimi20200216_new_.nwb', 'interval_list_name': 'pos 1 valid times', 'track_graph_name': '6 arm', 'linearization_param_name': 'default'}
Writing new NWB file chimi20200216_new_GBGCXYMIWB.nwb
| position_info_param_name: name for this set of parameters | nwb_file_name: name of the NWB file | interval_list_name: descriptive name of this interval list | track_graph_name | linearization_param_name: name for this set of parameters | analysis_file_name: name of the file | linearized_position_object_id |
|---|---|---|---|---|---|---|
| default | chimi20200216_new_.nwb | pos 1 valid times | 6 arm | default | chimi20200216_new_GBGCXYMIWB.nwb | 8d132da2-c1e4-402f-ba3a-4b8725a6c87a |
Total: 1
4. Examining the data
After populating the table, we can use the fetch1_dataframe method to retrieve the linear position data.
The dataframe has several variables:
- linear_position is the 1D linearized position.
- track_segment_id is the index of the edge (in the edges array given to the track graph) that the position was assigned to.
- projected_x_position and projected_y_position are the 2D position projected onto the track graph.

Time is the index of the dataframe.
linear_position_df = (
IntervalLinearizedPosition()
& {
"position_info_param_name": "default",
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
"track_graph_name": "6 arm",
"linearization_param_name": "default",
}
).fetch1_dataframe()
linear_position_df
| time | linear_position | track_segment_id | projected_x_position | projected_y_position |
|---|---|---|---|---|
| 1.581887e+09 | 412.042773 | 0 | 90.802281 | 210.677533 |
| 1.581887e+09 | 412.364853 | 0 | 90.520636 | 210.833775 |
| 1.581887e+09 | 412.686934 | 0 | 90.238990 | 210.990018 |
| 1.581887e+09 | 412.488270 | 0 | 90.412714 | 210.893645 |
| 1.581887e+09 | 412.007991 | 0 | 90.832697 | 210.660660 |
| ... | ... | ... | ... | ... |
| 1.581888e+09 | 340.401589 | 1 | 175.500994 | 212.958497 |
| 1.581888e+09 | 340.373902 | 1 | 175.477029 | 212.944630 |
| 1.581888e+09 | 340.394065 | 1 | 175.494481 | 212.954729 |
| 1.581888e+09 | 340.346214 | 1 | 175.453064 | 212.930764 |
| 1.581888e+09 | 340.318527 | 1 | 175.429100 | 212.916898 |
39340 rows × 4 columns
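The rows above can be reproduced by hand if we assume that edges are placed end to end in linear_edge_order, each keeping its 2D length and separated by linear_edge_spacing, and that linear_position is measured along the edge from the node listed first for that edge in linear_edge_order. The hypothetical helper linear_position_sketch below (plain NumPy, not the library's code) recovers linear_position for the first and last rows from their projected positions and track_segment_id.

```python
import numpy as np

# Track graph from section 2, repeated so this sketch runs on its own
node_positions = np.array(
    [
        (79.910, 216.720), (132.031, 187.806), (183.718, 217.713),
        (132.544, 132.158), (87.202, 101.397), (31.340, 126.110),
        (180.337, 104.799), (92.693, 42.345), (183.784, 45.375),
        (231.338, 136.281),
    ]
)
edges = np.array(
    [(0, 1), (1, 2), (1, 3), (3, 4), (4, 5), (3, 6), (6, 9), (4, 7), (6, 8)]
)
linear_edge_order = [
    (3, 6), (6, 8), (6, 9), (3, 1), (1, 2), (1, 0), (3, 4), (4, 5), (4, 7)
]
linear_edge_spacing = 15


def linear_position_sketch(projected_xy, track_segment_id):
    """Hypothetical helper: 1D position of a point already on an edge."""
    u, v = (int(n) for n in edges[track_segment_id])
    start = 0.0
    for a, b in linear_edge_order:
        if {a, b} == {u, v}:
            # Distance along the edge from the node that comes first in 1D
            return start + float(np.linalg.norm(projected_xy - node_positions[a]))
        # Skip past this edge and the gap that follows it
        start += float(np.linalg.norm(node_positions[b] - node_positions[a]))
        start += linear_edge_spacing
    raise ValueError(f"edge {(u, v)} is not in linear_edge_order")


# First and last rows of linear_position_df
first = linear_position_sketch(np.array([90.802281, 210.677533]), 0)
last = linear_position_sketch(np.array([175.429100, 212.916898]), 1)
print(round(first, 2), round(last, 2))
```

Both values agree with the table to within rounding, which supports reading linear_edge_order and linear_edge_spacing as described in section 2.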
Let's plot the linearized position over time, colored by track segment. As a reference, we can draw the 1D layout of the track graph along the y-axis.
fig, ax = plt.subplots(figsize=(20, 13))
ax.scatter(
linear_position_df.index,
linear_position_df.linear_position,
c=linear_position_df.track_segment_id,
s=1,
)
graph.plot_track_graph_as_1D(
ax=ax, axis="y", other_axis_start=linear_position_df.index[-1] + 10
)
ax.set_xlabel("Time [s]", fontsize=18)
ax.set_ylabel("Linear Position [cm]", fontsize=18)
ax.set_title("Linear Position", fontsize=28)
We can also plot the 2D position projected onto the track graph.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.plot(
position_info.head_position_x,
position_info.head_position_y,
color="lightgrey",
alpha=0.7,
zorder=0,
)
ax.set_xlabel("x-position [cm]", fontsize=18)
ax.set_ylabel("y-position [cm]", fontsize=18)
ax.plot(
linear_position_df.projected_x_position,
linear_position_df.projected_y_position,
)
5. Interactively selecting the track graph nodes and edges [Work in Progress]
NodePicker
The linearization depends heavily on how you specify the track graph, and setting the node positions and edges by hand can be difficult. To simplify this process, we can use the NodePicker to interactively set the node positions and edges based on the video of the track.
%matplotlib widget
from spyglass.common.common_position import NodePicker
import pynwb
key = {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 1 valid times",
}
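# The position interval "pos N valid times" corresponds to epoch N + 1
epoch = (
int(
key["interval_list_name"]
.replace("pos ", "")
.replace(" valid times", "")
)
+ 1
)
# Look up the video file recorded during this epoch
video_info = (
nd.common.common_behav.VideoFile()
& {"nwb_file_name": key["nwb_file_name"], "epoch": epoch}
).fetch1()
# Open the raw NWB file (this directory is specific to the lab's server; adjust it for your setup)
io = pynwb.NWBHDF5IO("/stelmo/nwb/raw/" + video_info["nwb_file_name"], "r")
nwb_file = io.read()
nwb_video = nwb_file.objects[video_info["video_file_object_id"]]
# Path of the external video file referenced by the ImageSeries
# (index directly: h5py >= 3 datasets no longer have a .value attribute)
video_filename = nwb_video.external_file[0]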
fig, ax = plt.subplots(figsize=(8, 8))
picker = NodePicker(ax=ax, video_filename=video_filename)
Once the nodes and edges have been selected, we can retrieve them using the node_positions and edges attributes. (The outputs below are empty because no nodes were selected when this notebook was run.)
picker.node_positions
array([], dtype=float64)
picker.edges
[[]]
Selector
We can also draw a 2D polygon around the track and attempt to recover the graph.
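The selector's internals are not shown here. At its core, any polygon selection needs a point-in-polygon test, so below is a minimal NumPy sketch of the standard even-odd (ray casting) rule, which could be applied to 2D positions such as head_position_x and head_position_y. matplotlib offers an equivalent test through matplotlib.path.Path.contains_points; whether SelectFromCollection relies on it is an assumption we have not verified.

```python
import numpy as np


def points_in_polygon(points, polygon):
    """Even-odd (ray casting) test: True for each point inside `polygon`.

    points: array (n, 2); polygon: array (m, 2) of vertices in drawing order.
    """
    points = np.asarray(points, dtype=float)
    polygon = np.asarray(polygon, dtype=float)
    x, y = points[:, 0], points[:, 1]
    inside = np.zeros(len(points), dtype=bool)
    # Pair each vertex with the next one, wrapping around to close the polygon
    for (x1, y1), (x2, y2) in zip(polygon, np.roll(polygon, -1, axis=0)):
        # Does this polygon edge straddle each point's y coordinate?
        straddles = (y1 > y) != (y2 > y)
        with np.errstate(divide="ignore", invalid="ignore"):
            # x where the edge crosses each point's horizontal line
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
        # Each crossing to the right of the point flips inside/outside
        inside ^= straddles & (x < x_cross)
    return inside


square = [(0, 0), (10, 0), (10, 10), (0, 10)]
points = [(5, 5), (15, 5), (-1, 2), (9.9, 0.1)]
print(points_in_polygon(points, square).tolist())  # [True, False, False, True]
```

Horizontal polygon edges never straddle a point's y coordinate, so the division by zero they produce is masked out and suppressed with np.errstate.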
%matplotlib widget
from spyglass.common.common_position import SelectFromCollection
fig, ax = plt.subplots(figsize=(8, 8))
selector = SelectFromCollection(ax, video_filename)
print("Select points in the figure by enclosing them within a polygon.")
print("Press the 'esc' key to start a new polygon.")
print("Try holding the 'shift' key to move all of the vertices.")
print("Try holding the 'ctrl' key to move a single vertex.")
Select points in the figure by enclosing them within a polygon. Press the 'esc' key to start a new polygon. Try holding the 'shift' key to move all of the vertices. Try holding the 'ctrl' key to move a single vertex.